#include "stdafx.h"
#include "LazyTest.h"


// Notes:
// 1. Every constructor must initialize all data members.
// 2. The data members are NULL when the object was created by the default constructor.
// 3. Derived-to-Base conversion is implemented with a member template conversion
//    operator plus a helper constructor, so that the reference count is shared.
// Non-intrusive reference-counted smart pointer: the share count lives in a
// separately allocated uint32_t that every copy (including SmartPtr<Base>
// copies produced by the converting operator) points at. The last owner to
// release its share deletes both the pointee and the counter.
// Invariant: m_pRefCount is NULL exactly when m_pNative is NULL.
// Not thread-safe: the counter is a plain uint32_t, not an atomic.
template <class T>
class SmartPtr
{
public:
	// Default constructor: an empty pointer that owns nothing (RefCount() == 0).
	SmartPtr() : m_pNative(NULL), m_pRefCount(NULL)
	{
	}

	// Takes ownership of p. A NULL p yields an empty pointer with no counter.
	// If allocating the counter throws, p is deleted before the exception is
	// rethrown, so the object handed to us is never leaked.
	SmartPtr(T* p) : m_pNative(p), m_pRefCount(NULL)
	{
		if(p != NULL)
		{
			try
			{
				m_pRefCount = new uint32_t(1);
			}
			catch(...)
			{
				delete p;
				throw;
			}
		}
	}

	// Copy constructor: shares ownership (and the counter) with rhs.
	SmartPtr(const SmartPtr& rhs) : m_pNative(rhs.m_pNative), m_pRefCount(rhs.m_pRefCount)
	{
		if(m_pRefCount) ++(*m_pRefCount);
	}

	// Assignment via copy-and-swap: take a share of rhs first, then hand our
	// old share to the temporary, which drops it on scope exit. This handles
	// self-assignment and an empty rhs without any special cases.
	SmartPtr& operator = (const SmartPtr& rhs)
	{
		SmartPtr tmp(rhs);
		Swap(tmp);
		return *this;
	}

	// Destructor: gives back this object's share, if it holds one.
	~SmartPtr()
	{
		if(m_pRefCount) // NULL when empty
			Release();
	}

	// Member access. Returns NULL for an empty pointer; using -> on that is undefined.
	T* operator -> () const
	{
		return m_pNative;
	}

	// Replaces the pointee with p, taking ownership of p (explicit call instead
	// of implicit assignment from a raw pointer). The old pointee is deleted if
	// this was its last owner; other copies keep the old object. Rebinding to
	// the pointer already held is a no-op.
	void ReBind(T* p)
	{
		if(m_pNative == p) return;

		if(m_pRefCount != NULL && *m_pRefCount == 1)
		{
			// Sole owner: destroy the old object and recycle the counter for p
			// (its value is already 1), or free it when becoming empty.
			delete m_pNative;
			m_pNative = p;
			if(p == NULL)
			{
				delete m_pRefCount;
				m_pRefCount = NULL;
			}
			return;
		}

		// Empty or shared: build the new owner first so that a failed counter
		// allocation leaves *this untouched, then let the temporary drop our
		// old share (if any).
		SmartPtr fresh(p);
		Swap(fresh);
	}

	// Get Native Pointer - Avoid implicit convert
	T* Native() const {return m_pNative;}
	// Number of SmartPtrs sharing the pointee; 0 when empty.
	uint32_t RefCount() const {return m_pRefCount ? *m_pRefCount : 0;}


	// SmartPtr assignment/copy with different T.
	// Helper constructor: adopts an existing counter and adds one share.
	// pRef must be the counter that already owns p (or NULL when p is NULL);
	// passing an unrelated counter breaks the ownership bookkeeping.
	SmartPtr(T* p, uint32_t* pRef) : m_pNative(p), m_pRefCount(pRef)
	{
		if(m_pRefCount) ++(*m_pRefCount);
	}

	// Converts to SmartPtr<U> whenever T* implicitly converts to U*
	// (e.g. Derived -> Base); the result shares this object's counter.
	template <class U>
	operator SmartPtr<U> () const
	{
		return SmartPtr<U>(m_pNative, m_pRefCount);
	}

private:
	// Drops one share; deletes the pointee and the counter if it was the last.
	// Precondition: m_pRefCount != NULL.
	void Release()
	{
		if(--(*m_pRefCount) == 0)
		{
			delete m_pNative;
			delete m_pRefCount;
		}
	}

	// Exchanges pointee and counter with other (no allocation, cannot throw).
	void Swap(SmartPtr& other)
	{
		T* pNative = m_pNative;
		m_pNative = other.m_pNative;
		other.m_pNative = pNative;

		uint32_t* pRef = m_pRefCount;
		m_pRefCount = other.m_pRefCount;
		other.m_pRefCount = pRef;
	}

private:
	T* m_pNative;
	uint32_t* m_pRefCount;
	
};

// Sample Class
class Vehicle
{
public:
	virtual ~Vehicle(){}
	void turn_left(){}
	void turn_right(){}
};

// Derived sample type, used to exercise SmartPtr's Derived -> Base conversion.
class Car : public Vehicle
{
public:
	void add_fuel() {}
};


// Test Cases
// Empty (default-constructed) SmartPtrs must stay empty (Native() == NULL)
// through copy construction, assignment and Derived -> Base conversion.
TESTCASE(test_defaultconstructor)
{
	// Constructor: default construction yields empty pointers
	SmartPtr<Vehicle> spV1;
	SmartPtr<Vehicle> spV2;
	SmartPtr<Car> spCar1;
	SmartPtr<Car> spCar2;
	ASSERT_TRUE(!spV1.Native()  && !spV2.Native() && !spCar1.Native()  && !spCar2.Native()) ;

	// Copy of an empty pointer is empty
	SmartPtr<Vehicle> spV3(spV1);
	ASSERT_TRUE(!spV3.Native());

	// Copy: Derived to Base (through SmartPtr<Car>'s operator SmartPtr<U> conversion)
	SmartPtr<Vehicle> spV4(spCar1);
	ASSERT_TRUE(!spV4.Native());


	// Assignment between empty pointers of the same type
	spV1 = spV2;
	spCar1 = spCar2;
	ASSERT_TRUE(!spV1.Native() && !spCar1.Native());

	// Assignment: Derived to Base, both sides empty
	spV1 = spCar1;
	spV2 = spCar2;
	ASSERT_TRUE(!spV1.Native()  && !spV2.Native());

	return true;
}



// Reference counting across copy construction, assignment and scope exit,
// plus construction from a NULL raw pointer.
TESTCASE(test_constructor)
{
	// Ref count: a freshly owned object has exactly one owner
	SmartPtr<Vehicle> spV1(new Vehicle());
	ASSERT_TRUE(spV1.RefCount() == 1);

	// Inner scope: spV2 and spV3 are destroyed at the closing brace
	{
		SmartPtr<Vehicle> spV2(spV1);
		ASSERT_TRUE(spV1.RefCount() == 2);

		SmartPtr<Vehicle> spV3;
		ASSERT_TRUE(spV3.RefCount() == 0);

		// Empty pointer joins the shared ownership
		spV3 = spV2;
		ASSERT_TRUE(spV1.RefCount() == 3);
	}

	// Both inner copies gave their share back when the scope ended
	ASSERT_TRUE(spV1.RefCount() == 1);

	// Pass in a null pointer: result is empty and has no counter
	SmartPtr<Vehicle> spV4(NULL);
	ASSERT_TRUE(spV4.Native() == NULL);
	ASSERT_TRUE(spV4.RefCount() == 0);

	return true;
}

// Test Cases
// Assignment between SmartPtrs of the same type and Derived -> Base:
// the pointee and counter are shared, and the previous share is released.
TESTCASE(test_assignment)
{
	SmartPtr<Vehicle> spV1(new Vehicle());
	SmartPtr<Vehicle> spV2(new Vehicle());
	SmartPtr<Car> spCar1(new Car());

	// Same type: spV1's original Vehicle is released; both now share spV2's.
	spV1 = spV2;
	
	ASSERT_TRUE(spV1.Native() == spV2.Native());
	ASSERT_TRUE(spV1.RefCount() == 2 && spV2.RefCount() == 2);

	// Derived -> Base: spV1 leaves spV2's object and joins spCar1's.
	spV1 = spCar1;
	ASSERT_TRUE(spV1.Native() == spCar1.Native());
	ASSERT_TRUE(spV1.RefCount() == 2 && spCar1.RefCount() == 2);
	ASSERT_TRUE(spV2.RefCount() == 1);

	// Self-assignment (through a reference) must leave the count untouched.
	SmartPtr<Vehicle> spV5(new Vehicle());
	SmartPtr<Vehicle>& spV5Ref = spV5;
	spV5 = spV5Ref;
	ASSERT_TRUE(spV5.RefCount() == 1);

	return true;
}

// ReBind from every starting state: sole owner -> NULL, empty -> object,
// shared -> new object (other owner keeps the old one), sole owner -> object.
TESTCASE(test_rebind)
{
	// Sole owner rebinding to NULL: the pointer becomes empty.
	SmartPtr<Vehicle> spV1(new Vehicle());
	spV1.ReBind(NULL);
	ASSERT_TRUE(spV1.Native() == NULL);
	ASSERT_TRUE(spV1.RefCount() == 0);

	// Empty pointer rebinding to a Derived object.
	Car* pCar = new Car();
	spV1.ReBind(pCar);
	ASSERT_TRUE(spV1.Native() == pCar);
	ASSERT_TRUE(spV1.RefCount() == 1);

	// Shared pointer rebinding: spV1 gets a fresh count, spV2 keeps pCar alone.
	SmartPtr<Vehicle> spV2 = spV1;
	spV1.ReBind(new Car());

	ASSERT_TRUE(spV1.RefCount() == 1);
	ASSERT_TRUE(spV2.Native() == pCar && spV2.RefCount() == 1);

	// Sole owner rebinding to another object (the counter is reused).
	spV1.ReBind(new Vehicle());
	ASSERT_TRUE(spV1.RefCount() == 1);

	return true;
}






// Entry point: runs the test cases through LazyTest's RUN_ALL_CASES macro
// (presumably every case declared with TESTCASE; see LazyTest.h).
int main()
{
	RUN_ALL_CASES();
	return 0;
}